/-
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Wojciech Różowski
-/

module

prelude

public import Lean.Elab.PreDefinition.PartialFixpoint
public import Lean.Elab.Tactic.Rewrite
public import Lean.Meta.Tactic.Simp
public import Lean.Linter.UnusedVariables
namespace Lean.Elab.Command
open Lean Meta Elab

-- Registers the `Elab.coinductive` trace class used by the `trace[Elab.coinductive]`
-- messages emitted throughout this file while elaborating `coinductive` declarations.
builtin_initialize
  registerTraceClass `Elab.coinductive

/-
  This file contains the main bits of the implementation of `coinductive` keyword.
  The main entry point is the `elabCoinductive`.

  At the beginning, elaboration of mutual blocks where some definitions are defined via the
  `coinductive` keyword is the same as that of `inductive`. However, in `elabInductives` we
  elaborate views as if they were `inductive` types, but just before replacing the free variables
  with constants and adding it to the kernel, we call `mkFlatInductive` that rewrites the inductives
  to the "flat" form, that is we add parameters for each of the definitions in the clique
  and replace recursive calls in constructors with these parameters. For example, the following definition

  ```
  variable (α : Type)
  coinductive infSeq (r : α → α → Prop) : α → Prop where
  | step : r a b → infSeq r b → infSeq r a
  ```

  yields the following "flat" inductive:
  ```
  inductive infSeq._functor (r : α → α → Prop) (infSeq._functor.call : α → Prop) : α → Prop where
  | step : r a b → infSeq._functor.call b → infSeq._functor r infSeq._functor.call a
  ```

  After this rewrite, the code for adding flat inductives does not diverge much from the usual
  way it is done for inductive declarations, but we omit applying attributes/modifiers and
  we do not set the syntax references to track those declarations (as they are auxiliary
  data hidden from the user).

  Then, upon adding such flat inductives for each definition in the mutual block to the environment,
  we use the `Meta.MkIffOfInductiveProp` machinery to rewrite those to predicates made of disjunctions
  and existentials that we will refer to as the "existential" form. This form makes it easy to generate
  user-readable coinduction proof principles and allows us to use the existing `monotonicity` tactic.

  For example, the above flat inductive corresponds to:
  ```
    def infSeq._functor.existential : (α : Type) → (α → α → Prop) → (α → Prop) → α → Prop :=
      fun α r infSeq._functor.call a => ∃ b, r a b ∧ infSeq._functor.call b
  ```

  Both forms are connected through the following lemma (generated by the
  `Meta.MkIffOfInductiveProp` machinery):
  ```
    infSeq._functor.existential_equiv (α : Type) (r : α → α → Prop)
      (infSeq._functor.call : α → Prop) (a✝ : α) :
      infSeq._functor α r infSeq._functor.call a✝ ↔ ∃ b, r a✝ b ∧ infSeq._functor.call b
  ```

  Those definitions are used to populate `PreDefinition`s that are then passed to `PartialFixpoint`
  machinery.

  At that stage all predicates (if definitions are monotone) are added to the environment.
  Note that at this point `PartialFixpoint` machinery applies the attributes and modifiers. We
  use the syntax references from the original `InductiveView`s and set them to those declarations.

  Moreover, we have following theorem (generated by `generateEqLemmas`) that connects the coinductive
  predicate to its flat inductive:
  ```
    infSeq.functor_unfold (α : Type) (r : α → α → Prop) (a✝ : α) : infSeq α r a✝ = infSeq._functor α r (infSeq α r) a✝
  ```

  We use these to define all the constructors from the original definition. For example, we obtain:
  ```
    infSeq.step (α : Type) (r : α → α → Prop) {a b : α} : r a b → infSeq α r b → infSeq α r a
  ```

  Similarly, we obtain the associated `casesOn` eliminator (generated by `mkCasesOnCoinductive`):
  ```
    infSeq.casesOn (α : Type) (r : α → α → Prop) {motive : (a : α) → infSeq α r a → Prop} {a✝ : α} (t : infSeq α r a✝)
    (step : ∀ {a b : α} (a_1 : r a b) (a_2 : infSeq α r b), motive a (infSeq.step α r a_1 a_2)) : motive a✝ t
  ```

  At the very end, we make use of the syntax references from the original `InductiveView`s
  and set them to newly generated constructors. We apply deriving handlers and docstrings.
  Note that attributes and modifiers are handled earlier by the `PartialFixpoint` machinery.
-/

/-- Data carried in `InductiveElabStep1` that is used only when elaborating mutual blocks
containing coinductive predicates (it is the input of `elabCoinductive`). -/
public structure CoinductiveElabData where
  /-- Declaration Id from the original `InductiveView`. -/
  declId : Syntax
  /-- Name under which the flat inductive is registered. `elabCoinductive` looks it up with
  `getConstInfoInduct` and strips the `_functor` component (see `removeFunctorPostfix`) to obtain
  the name of the predicate. -/
  declName : Name
  /-- Ref from the original `InductiveView`; used as the `ref` of the generated `PreDefinition`. -/
  ref : Syntax
  /-- Modifiers from the original `InductiveView`; forwarded to the generated `PreDefinition`. -/
  modifiers : Modifiers
  /-- Constructor refs from the original `InductiveView`, paired positionally with the
  constructors of the flat inductive when the user-facing constructors are generated. -/
  ctorSyntax : Array Syntax
  /-- The flag that is `true` if the predicate was defined via the `coinductive` keyword and
  `false` otherwise. When we elaborate a mutual definition, we allow mixing `coinductive` and
  `inductive` keywords, and hence we need to record this information. It selects between
  `.coinductiveFixpoint` and `.inductiveFixpoint` in `elabCoinductive`.
  -/
  isGreatest : Bool
  deriving Inhabited


/-- Appends the `_functor` component to a name, e.g. `infSeq` becomes `infSeq._functor`. -/
public def addFunctorPostfix : Name → Name :=
  fun n => n ++ `_functor

/-- Drops the last component of the (base) name via `Name.modifyBase`, e.g. `infSeq._functor`
becomes `infSeq`. It does not check that the dropped component is actually `_functor`. -/
public def removeFunctorPostfix : Name → Name :=
  fun n => n.modifyBase Name.getPrefix

/-- Maps a flat constructor name `P._functor.c` to `P.c` by applying `removeFunctorPostfix` to
its prefix and keeping the last (string) component. Panics if the last component is not a
string. -/
public def removeFunctorPostfixInCtor : Name → Name
  | .str pre last => .str (removeFunctorPostfix pre) last
  | _ => panic! "UnexpectedName"

/-- Rewrites the target of `goal` using the equation (or `Iff`) proof `eq`, right-to-left when
`symm` is `true`, and returns the new goal produced by `MVarId.replaceTargetEq`. -/
private def rewriteGoalUsingEq (goal : MVarId) (eq : Expr) (symm : Bool := false) : MetaM MVarId := do
  let target ← goal.getType
  let result ← goal.rewrite target eq symm
  goal.replaceTargetEq result.eNew result.eqProof

/--
  Generates unfolding lemmas that relate coinductive predicates to their flat inductive forms.
  These lemmas are used when generating the user-facing constructors and `casesOn` eliminators,
  providing the bridge between the user-facing coinductive predicates and their internal flat
  representations.

  For every flat inductive `P._functor` in `infos`, a declaration `P.functor_unfold` is added,
  stating that `P` applied to the original parameters and the indices equals `P._functor`
  applied to the original parameters, every predicate of the mutual block (itself applied to
  the original parameters) in place of the recursive-call parameters, and the indices.

  Example:

  Given a definition:
  ```
  coinductive infSeq (r : α → α → Prop) : α → Prop where
  | step : r a b → infSeq r b → infSeq r a
  ```

  We generate the following unfolding lemma:
  ```
    infSeq.functor_unfold (α : Type) (r : α → α → Prop) (a✝ : α) : infSeq α r a✝ = infSeq._functor α r (infSeq α r) a✝
  ```
-/
private def generateEqLemmas (infos : Array InductiveVal) : MetaM Unit := do
  -- The universe levels of the first flat inductive are used for every member of the block.
  let levels := infos[0]!.levelParams.map mkLevelParam
  for info in infos do
    let res ← forallTelescopeReducing info.type fun args _ => do
      -- The flat parameters are the original parameters followed by one parameter per
      -- predicate of the block (the recursive calls); everything after them is an index.
      let params := args[:info.numParams - infos.size]
      let args := args[info.numParams:]

      -- lhs: the predicate applied to the original parameters and the indices.
      let lhs := mkConst (removeFunctorPostfix info.name) levels
      let lhs := mkAppN lhs params
      let lhs := mkAppN lhs args

      -- rhs: the flat inductive with every recursive-call parameter instantiated by the
      -- corresponding predicate (applied to the original parameters).
      let calls := infos.map fun info => mkAppN (mkConst (removeFunctorPostfix info.name) levels) params
      let rhs := mkConst info.name levels
      let rhs := mkAppN rhs (params ++ calls ++ args)

      let goalType ← mkEq lhs rhs
      let goal ← mkFreshExprMVar goalType

      let goalMVarId := goal.mvarId!

      -- The predicate must have exactly one equation lemma; otherwise we report an error.
      let .some #[fixEq] ←  getEqnsFor? (removeFunctorPostfix info.name) | throwError "did not generate unfolding theorem"
      -- Relates the flat inductive to its existential form (see `Meta.MkIffOfInductiveProp`).
      let existentialEquiv := mkConst (info.name ++ `existential_equiv) levels

      -- Specialise the equation lemma to the parameters, then to each index via `congrFun`.
      let mut fixEq := mkConst fixEq levels
      fixEq := mkAppN fixEq params
      for arg in args do
        fixEq ← mkCongrFun fixEq arg

      -- Rewrite the flat inductive on the rhs into its existential form and close the
      -- remaining goal with the specialised equation.
      -- NOTE(review): `assign` does not type-check; this relies on the rewritten goal being
      -- definitionally equal to the type of `fixEq` — confirm if the existential form changes.
      let newGoal ← rewriteGoalUsingEq goalMVarId existentialEquiv
      newGoal.assign fixEq

      let goal ← instantiateMVars goal
      mkLambdaFVars (params ++ args) goal
    trace[Elab.coinductive] "res: {res}"
    -- Registered as an opaque definition named `<predicate>.functor_unfold`.
    addDecl <|
      .defnDecl <|
        ←mkDefinitionValInferringUnsafe
          (name := (removeFunctorPostfix info.name) ++ `functor_unfold)
          (levelParams := info.levelParams)
          (type := (←inferType res))
          (value := res)
          (hints := .opaque)

/--
  Generates the user-facing constructor of a coinductive predicate corresponding to the flat
  inductive constructor `ctor`. It is registered under the name obtained by stripping the
  `_functor` component from `ctor.name` (see `removeFunctorPostfixInCtor`).

  The process:
  1. Takes the flat inductive constructor type and introduces its parameters.
  2. Fills the recursive-call parameters with the actual predicates (applied to the original
     parameters).
  3. States the conclusion in terms of the predicate and rewrites it with the
     `functor_unfold` lemma into the flat inductive.
  4. Closes that goal with the flat constructor, abstracts the parameters and fields, and adds
     the result as an opaque definition.

  Parameters:
  * `infos`: the flat inductives of the mutual block.
  * `ctorSyntax`: the constructor's syntax from the original `InductiveView`; it is recorded as
    the binder of the generated constant in the info tree.
  * `numParams`: the number of parameters of the original (non-flat) definitions.
  * `name`: the name of the flat inductive that `ctor` belongs to.
  * `ctor`: the flat inductive constructor.
-/
private def generateCoinductiveConstructor (infos : Array InductiveVal) (ctorSyntax : Syntax)
    (numParams : Nat) (name : Name) (ctor : ConstructorVal) : TermElabM Unit := do
  let ctorName := removeFunctorPostfixInCtor ctor.name
  trace[Elab.coinductive] "Generating constructor: {ctorName}"
  let numPreds := infos.size
  let predNames := infos.map fun val => removeFunctorPostfix val.name
  let levelParams := infos[0]!.levelParams.map mkLevelParam
  -- Introduce the original parameters followed by one parameter per recursive call.
  forallBoundedTelescope ctor.type (numParams + numPreds) fun args body => do
    let params := args.take numParams
    let predFVars := args[numParams:]
    -- Fill the recursive-call slots with the (co)inductive predicates being defined.
    let predicates : Array Expr := predNames.map fun predName =>
      mkAppN (mkConst predName levelParams) params
    let body := body.replaceFVars predFVars predicates
    -- Introduce the constructor fields and inspect its conclusion.
    let res ← forallTelescope body fun fields bodyExpr => do
      -- The conclusion is the flat inductive applied to all flat parameters and then the
      -- indices; only the indices are kept.
      let bodyAppArgs := bodyExpr.getAppArgs[numParams + numPreds:]
      -- Goal: the predicate applied to the original parameters and those indices.
      let goalType := mkConst (removeFunctorPostfix name) levelParams
      let goalType := mkAppN goalType params
      let goalType := mkAppN goalType bodyAppArgs
      trace[Elab.coinductive] "The conclusion of the constructor {ctor.name} is {goalType}"
      let goal ← mkFreshExprMVar goalType
      -- Unfold the predicate into its flat inductive via `functor_unfold`.
      let unfoldEq := mkConst (removeFunctorPostfix name ++ `functor_unfold) levelParams
      let unfoldEq := mkAppN unfoldEq params
      let newHole ← rewriteGoalUsingEq goal.mvarId! unfoldEq
      -- The rewritten goal is closed by the flat constructor itself.
      let ctorApp := mkAppN (mkConst ctor.name levelParams) (params ++ predicates ++ fields)
      newHole.assign ctorApp
      let conclusion ← instantiateMVars goal
      mkLambdaFVars params (← mkLambdaFVars fields conclusion)
    let type ← inferType res
    trace[Elab.coinductive] "The elaborated constructor is of the type: {type}"
    addDecl <|
      .defnDecl <|
        ←mkDefinitionValInferringUnsafe
          (name := ctorName)
          (levelParams := ctor.levelParams)
          (type := type)
          (value := res)
          (hints := .opaque)
    -- Record the constructor syntax as the binder of the new constant (rather than of its proof
    -- term), so that hover and go-to-definition resolve to the generated declaration.
    Term.addTermInfo' ctorSyntax (← mkConstWithLevelParams ctorName) (isBinder := true)

/--
  Generates the user-facing constructors of every predicate in a mutual block.

  The flat inductives `infos` are paired positionally with `coinductiveElabData`. Within each
  pair, the flat constructors are paired positionally with the recorded constructor syntax.
  Iteration stops at the shorter sequence of each pair. `numParams` is the number of parameters
  of the original (non-flat) definitions.
-/
private def generateCoinductiveConstructors (numParams : Nat) (infos : Array InductiveVal)
    (coinductiveElabData : Array CoinductiveElabData) : TermElabM Unit := do
  for (indType, elabData) in infos.zip coinductiveElabData do
    for (ctor, ctorSyntax) in indType.ctors.toArray.zip elabData.ctorSyntax do
      let ctorVal ← getConstInfoCtor ctor
      generateCoinductiveConstructor infos ctorSyntax numParams indType.name ctorVal

/--
  Generates `casesOn` eliminators for coinductive predicates.
  These eliminators allow pattern matching on coinductive predicates,
  enabling case analysis in proofs.

  For each flat inductive `P._functor` in `infos`, the eliminator's type is the type of the flat
  `casesOn`, instantiated with the original parameters and with the predicates in place of the
  recursive-call parameters. In that type, applications of the flat inductive and of the flat
  constructors are replaced by the predicate and by the generated constructors.

  Its value is the flat `casesOn`, with the motive and the major premise transported along the
  `functor_unfold` lemma. The result is added under `mkCasesOnName` of the predicate and tagged
  with the `cases_eliminator` and `elab_as_elim` attributes.
-/
private def mkCasesOnCoinductive (infos : Array InductiveVal) : MetaM Unit := do
  let levels := infos[0]!.levelParams.map mkLevelParam
  let allCtors := infos.flatMap (·.ctors.toArray)

  -- Introduce only the original parameters (flat parameters minus one per predicate).
  forallBoundedTelescope infos[0]!.type (infos[0]!.numParams - infos.size) fun params _ => do
  let predicates := infos.map fun info => mkConst (removeFunctorPostfix info.name) levels
  let predicates := predicates.map (mkAppN · params)
  for info in infos do
    let casesOnName := Lean.mkCasesOnName info.name
    let casesOnInfo ← getConstInfo casesOnName
    let originalCasesOn ← mkConstWithLevelParams casesOnName
    -- The flat `casesOn` with the original parameters and the predicates applied.
    let originalCasesOn := mkAppN originalCasesOn (params ++ predicates)

    let goalTypeWithParamsApplied ← inferType originalCasesOn
    /-
      We replace the mentions of the flat inductive with a coinductive predicate
      and replace all constructors of the original type.
    -/
    let goalTypeWithParamsApplied := goalTypeWithParamsApplied.replace (fun e =>
      if e.isApp then
        -- Drop all flat parameters and re-apply only the original `params`.
        let bodyArgs := e.getAppArgs[info.numParams:]
        if e.isAppOf info.name then
          mkAppN (mkConst (removeFunctorPostfix info.name) levels) <| params ++ bodyArgs
        else
          if allCtors.any e.isAppOf then
            -- NOTE(review): recomputes the same `bodyArgs` as above.
            let bodyArgs := e.getAppArgs[info.numParams:]
            mkAppN (mkConst (removeFunctorPostfixInCtor (e.getAppFn.constName)) levels)
              <| params ++ bodyArgs
          else none
      else
        none
    )
    -- The type of `casesOn` of the flat inductive, upon having the parameters applied
    let originalType ← inferType originalCasesOn
    -- The equivalence proof, that will be used in the subsequent rewrites
    let eqProof := mkConst ((removeFunctorPostfix info.name) ++ `functor_unfold) levels
    /-
      First, we look at the motive. We construct a free variable `motive`
      of the type of motive, as it appears in the `goalTypeWithParamsApplied`
    -/
    forallBoundedTelescope goalTypeWithParamsApplied (.some 1) fun args goalType => do
      let #[motive] := args
        | throwError "Expected one argument"
      /-
        Similarly, we pull off the type of the motive, as it appears in the `casesOn`
        of the flat inductive. We then make an mvar of this type and try to
        fill it using `motive` fvar.
      -/
      let (Expr.forallE _ type _ _) := originalType
        | throwError "expected to be quantifier"
      let motiveMVar ← mkFreshExprMVar type
      /-
        We intro all the indices and the occurrence of the coinductive predicate
      -/
      let (fvars, subgoal) ← motiveMVar.mvarId!.introN (info.numIndices + 1)
      subgoal.withContext do
        let lastAssumption := fvars[fvars.size -1]!

        -- We perform the rewrite at the hypothesis
        let rewriteTarget := (←getLCtx).get! lastAssumption
        let rewriteTarget := rewriteTarget.type
        let rewriteResult ← subgoal.rewrite rewriteTarget eqProof (symm := true)
        let replacementResult ← subgoal.replaceLocalDecl lastAssumption
              rewriteResult.eNew rewriteResult.eqProof

        -- Revert the indices together with the rewritten hypothesis, which replaces the original one.
        let newFVars := fvars.modify (fvars.size - 1) fun _ => replacementResult.fvarId
        let (_, afterReplacing) ← replacementResult.mvarId.revert newFVars
        -- Now it is in the form that we can assign the `motive` fvar to the goal
        afterReplacing.assign motive
        -- Then we apply the metavariable to the `casesOn` of the flat inductive
        let originalCasesOn := mkApp originalCasesOn motiveMVar

        -- The next arguments of the `casesOn` are type indices
        forallBoundedTelescope goalType info.numIndices fun indices goalType => do
          /-
            The types do not change, so we just make free variables for them
            and apply them to the `casesOn` of the flat inductive
          -/
          let originalCasesOn := mkAppN originalCasesOn indices
          /-
            The next argument is the occurrence of the coinductive predicate.
            The original `casesOn` of the flat inductive mentions it in
            unrolled form, so we need to rewrite it.
          -/
          forallBoundedTelescope goalType (.some 1) fun targetArgs _ => do
            /-
              We again make a free variable of the type as it appears in the desired
              type of `casesOn` for the coinductive predicate.
            -/
            let #[target] := targetArgs | throwError "Expected one argument"
            /-
              Then, we fish out the type as it appears in the `casesOn` of the flat
              inductive, then making a metavariable for it.
            -/
            let (Expr.forallE _ type _ _) ← inferType originalCasesOn
              | throwError "expected to be quantifier"
            let targetMVar ← mkFreshExprMVar type
            let targetMVarSubgoal ← rewriteGoalUsingEq targetMVar.mvarId! eqProof (symm := true)
            targetMVarSubgoal.assign target
            -- Upon performing the rewrite, we apply the mvar to the flat inductive `casesOn`
            let originalCasesOn := mkApp originalCasesOn targetMVar

            -- The remaining minor premises are left as the Pi-type of the resulting value.
            let originalCasesOn ←
              mkLambdaFVars (params ++ args ++ indices ++ targetArgs) originalCasesOn
            let originalCasesOn ← instantiateMVars originalCasesOn

            let levelParams := casesOnInfo.levelParams
            let casesOnName := mkCasesOnName (removeFunctorPostfix info.name)
            let casesOnType ← mkForallFVars params goalTypeWithParamsApplied
            addDecl <|
              .defnDecl <|
                ← mkDefinitionValInferringUnsafe
                    (name := casesOnName)
                    (levelParams := levelParams)
                    (type := casesOnType)
                    (value := originalCasesOn)
                    (hints := .opaque)
            -- We apply the attribute so that the `cases` tactic can pick it up
            liftCommandElabM
              <| liftTermElabM
                <| Term.applyAttributes
                    casesOnName #[{name := `cases_eliminator}, {name := `elab_as_elim}]

/--
  Main entry point for elaborating mutual (co)inductive predicates. It is called after the flat
  inductives named by `coinductiveElabData` have been added to the environment.

  Steps:
  1. Recover the user-facing names (dropping the `_functor` component) and types (dropping the
     recursive-call parameters) of the predicates.
  2. Build one `PreDefinition` per predicate. Its body is the unfolded existential form of the
     corresponding flat inductive (see `Meta.MkIffOfInductiveProp`), applied to the parameters
     and with the predicates themselves in place of the recursive calls.
  3. Hand them to the `PartialFixpoint` machinery. Each is marked as a greatest or least fixpoint
     according to `isGreatest`.
  4. Generate the `functor_unfold` lemmas, the constructors corresponding to the ones given by
     the user, and the `casesOn` eliminators.
-/
public def elabCoinductive (coinductiveElabData : Array CoinductiveElabData) : TermElabM Unit := do
  trace[Elab.coinductive] "Elaborating: {coinductiveElabData.map (·.declName)}"
  let infos ← coinductiveElabData.mapM (getConstInfoInduct ·.declName)
  let levelParams := infos[0]!.levelParams.map mkLevelParam
  -- Every flat inductive carries one extra parameter per predicate of the block.
  let originalNumParams := infos[0]!.numParams - infos.size
  let namesAndTypes : Array (Name × Expr) ← infos.mapM fun info => do
    -- Keep the original parameters and the indices; skip the recursive-call parameters.
    let type ← forallTelescope info.type fun args body => do
      mkForallFVars (args[:originalNumParams] ++ args[info.numParams:]) body
    return (removeFunctorPostfix info.name, type)
  -- References to the predicates being defined, used as the recursive calls.
  let consts := namesAndTypes.map fun (name, _) => mkConst name levelParams
  -- Body of each predicate: the existential form with parameters and recursive calls applied.
  let preDefVals ← forallBoundedTelescope infos[0]!.type originalNumParams fun params _ => do
    infos.mapM fun info => do
      let existential ← unfoldDefinition (mkConst (info.name ++ `existential) levelParams)
      let recCalls := consts.map (mkAppN · params)
      mkLambdaFVars params (mkAppN existential (params ++ recCalls))
  let preDefs : Array PreDefinition := preDefVals.mapIdx fun idx defn =>
    let data := coinductiveElabData[idx]!
    let entry := namesAndTypes[idx]!
    { ref := data.ref
      binders := data.ref
      kind := .def
      levelParams := infos[0]!.levelParams
      modifiers := data.modifiers
      declName := entry.1
      type := entry.2
      value := defn
      termination := {
        ref := data.ref
        terminationBy?? := .none
        terminationBy? := .none
        partialFixpoint? := .some {
            ref := data.ref
            term? := .none
            fixpointType := if data.isGreatest then
                              .coinductiveFixpoint else .inductiveFixpoint
        }
        decreasingBy? := .none
        extraParams := 0
      }
    }
  partialFixpoint (← getLCtx, ← getLocalInstances) preDefs
  generateEqLemmas infos
  generateCoinductiveConstructors originalNumParams infos coinductiveElabData
  mkCasesOnCoinductive infos

end Lean.Elab.Command
